Explaining Text Classification

# Project-local explainer entry points: SHAP-style feature attributions and metrics helpers.
from explainer.explainers import feature_attributions_explainer, metrics_explainer
import warnings
warnings.filterwarnings('ignore')  # keep sklearn/TF deprecation noise out of the notebook output
import os
os.environ['KMP_WARNINGS'] = 'off'  # silence Intel OpenMP (KMP) runtime warnings

import numpy as np
from sklearn import datasets

# Full list of the 20 newsgroups categories, kept for reference; only a
# 5-class subset is actually used below.
all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
                  'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
                  'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
                  'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
                  'talk.politics.misc','talk.religion.misc']

# Five well-separated topics keep the classification task small and interpretable.
selected_categories = ['alt.atheism','comp.graphics','rec.motorcycles','sci.space','talk.politics.misc']

# Raw newsgroup posts (text) with integer labels, restricted to the selected categories.
X_train_text, Y_train = datasets.fetch_20newsgroups(subset="train", categories=selected_categories, return_X_y=True)
X_test_text , Y_test  = datasets.fetch_20newsgroups(subset="test", categories=selected_categories, return_X_y=True)

# Wrap the Python lists in numpy arrays so they support slicing/fancy indexing later on.
X_train_text = np.array(X_train_text)
X_test_text = np.array(X_test_text)

# NOTE(review): zipping assumes fetch_20newsgroups assigns label ids in sorted
# category order (selected_categories is already alphabetical) — the printed
# mapping below confirms it for this run.
classes = np.unique(Y_train)
mapping = dict(zip(classes, selected_categories))

# Notebook display: dataset sizes, the label ids, and the id -> category mapping.
len(X_train_text), len(X_test_text), classes, mapping
(2720,
 1810,
 array([0, 1, 2, 3, 4]),
 {0: 'alt.atheism',
  1: 'comp.graphics',
  2: 'rec.motorcycles',
  3: 'sci.space',
  4: 'talk.politics.misc'})
print(Y_test)
[2 3 0 ... 3 2 3]

Vectorize Text Data

import sklearn
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

# TF-IDF bag-of-words features with the vocabulary capped at 50k terms.
vectorizer = TfidfVectorizer(max_features=50000)

# NOTE(review): the vectorizer is fit on train + test text together, which
# leaks test-set vocabulary/document frequencies into the features. Fitting on
# X_train_text alone is the methodologically clean choice — confirm whether
# the leakage is intentional (it would change all downstream outputs).
vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)

# Densify the sparse matrices; the Keras model below consumes dense arrays.
X_train, X_test = X_train.toarray(), X_test.toarray()

X_train.shape, X_test.shape
((2720, 50000), (1810, 50000))

Define the Model

from tensorflow.keras.models import Sequential
from tensorflow.keras import layers

def create_model():
    """Build a small dense softmax classifier over the TF-IDF features.

    The input width is taken from the global ``X_train``; the output layer
    has one unit per class in the global ``classes``.
    """
    net = Sequential()
    net.add(layers.Input(shape=X_train.shape[1:]))
    net.add(layers.Dense(128, activation="relu"))
    net.add(layers.Dense(64, activation="relu"))
    net.add(layers.Dense(len(classes), activation="softmax"))
    return net

# Instantiate the network and print its layer/parameter summary.
model = create_model()
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 128)               6400128   
                                                                 
 dense_1 (Dense)             (None, 64)                8256      
                                                                 
 dense_2 (Dense)             (None, 5)                 325       
                                                                 
=================================================================
Total params: 6,408,709
Trainable params: 6,408,709
Non-trainable params: 0
_________________________________________________________________

Compile and Train Model

# Sparse categorical cross-entropy: integer labels against softmax outputs.
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(X_train, Y_train, batch_size=256, epochs=5, validation_data=(X_test, Y_test))

Evaluate Model Performance

from sklearn.metrics import accuracy_score, classification_report

# Class-probability predictions (softmax outputs) for both splits.
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

# argmax over the class axis converts probabilities to hard label predictions.
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
Hide code cell output
 1/85 [..............................] - ETA: 10s

 7/85 [=>............................] - ETA: 0s 

19/85 [=====>........................] - ETA: 0s

32/85 [==========>...................] - ETA: 0s

41/85 [=============>................] - ETA: 0s

51/85 [=================>............] - ETA: 0s

63/85 [=====================>........] - ETA: 0s

76/85 [=========================>....] - ETA: 0s

85/85 [==============================] - 1s 5ms/step
 1/57 [..............................] - ETA: 5s

11/57 [====>.........................] - ETA: 0s

21/57 [==========>...................] - ETA: 0s

32/57 [===============>..............] - ETA: 0s

44/57 [======================>.......] - ETA: 0s

56/57 [============================>.] - ETA: 0s

57/57 [==============================] - 0s 5ms/step
Train Accuracy : 1.000
Test  Accuracy : 0.949

Classification Report : 
                    precision    recall  f1-score   support

       alt.atheism       0.98      0.91      0.94       319
     comp.graphics       0.94      0.96      0.95       389
   rec.motorcycles       0.98      0.98      0.98       398
         sci.space       0.93      0.94      0.93       394
talk.politics.misc       0.92      0.95      0.93       310

          accuracy                           0.95      1810
         macro avg       0.95      0.95      0.95      1810
      weighted avg       0.95      0.95      0.95      1810
# one-hot-encode classes: the confusion-matrix explainer expects one-hot targets.
oh_Y_test = np.eye(len(classes))[Y_test]

# Build and render the confusion matrix, then print its text report.
cm = metrics_explainer['confusionmatrix'](oh_Y_test, test_preds, selected_categories)
cm.visualize()
print(cm.report)
                    precision    recall  f1-score   support

       alt.atheism       0.98      0.91      0.94       319
     comp.graphics       0.94      0.96      0.95       389
   rec.motorcycles       0.98      0.98      0.98       398
         sci.space       0.93      0.94      0.93       394
talk.politics.misc       0.92      0.95      0.93       310

          accuracy                           0.95      1810
         macro avg       0.95      0.95      0.95      1810
      weighted avg       0.95      0.95      0.95      1810
../../_images/6868607e8e23df4b889c3eb3eec6dbe0b7d2e767f35ec9eb0028e5fbd550493c.png
# Precision-recall and ROC curves from the same one-hot targets and probabilities.
plotter = metrics_explainer['plot'](oh_Y_test, test_preds, selected_categories)
plotter.pr_curve()
plotter.roc_curve()
import re

# Pick two test samples (indices 1 and 2) to explain below.
X_batch_text = X_test_text[1:3]
X_batch = X_test[1:3]

print("Samples : ")
for text in X_batch_text:
    # Tokenize on non-word characters — the same split the explainer uses later.
    print(re.split(r"\W+", text))
    print()

preds_proba = model.predict(X_batch)
preds = preds_proba.argmax(axis=1)

print("Actual    Target Values : {}".format([selected_categories[target] for target in Y_test[1:3]]))
print("Predicted Target Values : {}".format([selected_categories[target] for target in preds]))
print("Predicted Probabilities : {}".format(preds_proba.max(axis=1)))
Samples : 
['From', 'prb', 'access', 'digex', 'net', 'Pat', 'Subject', 'Re', 'Near', 'Miss', 'Asteroids', 'Q', 'Organization', 'Express', 'Access', 'Online', 'Communications', 'Greenbelt', 'MD', 'USA', 'Lines', '4', 'Distribution', 'sci', 'NNTP', 'Posting', 'Host', 'access', 'digex', 'net', 'TRry', 'the', 'SKywatch', 'project', 'in', 'Arizona', 'pat', '']

['From', 'cobb', 'alexia', 'lis', 'uiuc', 'edu', 'Mike', 'Cobb', 'Subject', 'Science', 'and', 'theories', 'Organization', 'University', 'of', 'Illinois', 'at', 'Urbana', 'Lines', '19', 'As', 'per', 'various', 'threads', 'on', 'science', 'and', 'creationism', 'I', 've', 'started', 'dabbling', 'into', 'a', 'book', 'called', 'Christianity', 'and', 'the', 'Nature', 'of', 'Science', 'by', 'JP', 'Moreland', 'A', 'question', 'that', 'I', 'had', 'come', 'from', 'one', 'of', 'his', 'comments', 'He', 'stated', 'that', 'God', 'is', 'not', 'necessarily', 'a', 'religious', 'term', 'but', 'could', 'be', 'used', 'as', 'other', 'scientific', 'terms', 'that', 'give', 'explanation', 'for', 'events', 'or', 'theories', 'without', 'being', 'a', 'proven', 'scientific', 'fact', 'I', 'think', 'I', 'got', 'his', 'point', 'I', 'can', 'quote', 'the', 'section', 'if', 'I', 'm', 'being', 'vague', 'The', 'examples', 'he', 'gave', 'were', 'quarks', 'and', 'continental', 'plates', 'Are', 'there', 'explanations', 'of', 'science', 'or', 'parts', 'of', 'theories', 'that', 'are', 'not', 'measurable', 'in', 'and', 'of', 'themselves', 'or', 'can', 'everything', 'be', 'quantified', 'measured', 'tested', 'etc', 'MAC', 'Michael', 'A', 'Cobb', 'and', 'I', 'won', 't', 'raise', 'taxes', 'on', 'the', 'middle', 'University', 'of', 'Illinois', 'class', 'to', 'pay', 'for', 'my', 'programs', 'Champaign', 'Urbana', 'Bill', 'Clinton', '3rd', 'Debate', 'cobb', 'alexia', 'lis', 'uiuc', 'edu', 'Nobody', 'can', 'explain', 'everything', 'to', 'anybody', 'G', 'K', 'Chesterton', '']


1/1 [==============================] - ETA: 0s

1/1 [==============================] - 0s 38ms/step
Actual    Target Values : ['sci.space', 'alt.atheism']
Predicted Target Values : ['sci.space', 'alt.atheism']
Predicted Probabilities : [0.9238798  0.75361186]

SHAP Partition Explainer

Visualize SHAP Values for Correct Predictions

def make_predictions(X_batch_text):
    """Vectorize raw texts with the fitted TF-IDF vectorizer and return the
    model's class-probability predictions for them."""
    features = vectorizer.transform(X_batch_text).toarray()
    return model.predict(features)

partition_explainer = feature_attributions_explainer.partitionexplainer(make_predictions, r"\W+", selected_categories)(X_batch_text)

Text Plot

partition_explainer.visualize()


[0]
outputs
alt.atheism
comp.graphics
rec.motorcycles
sci.space
talk.politics.misc


0.50.30.10.70.90.1462830.146283base value0.0121570.012157falt.atheism(inputs)0.017 Arizona. 0.007 TRry 0.006 SKywatch 0.004 Miss 0.001 Re: -0.015 project -0.011 pat -0.009 digex. -0.008 access. -0.008 Pat) -0.008 sci -0.008 net -0.008 Online -0.008 prb@ -0.007 net ( -0.007 Access -0.007 access. -0.006 Express -0.006 digex. -0.006 Greenbelt, -0.005 Communications, -0.004 Asteroids ( -0.004 Near -0.004 Distribution: -0.004 USA -0.004 Subject: -0.003 Organization: -0.003 MD -0.003 From: -0.003 NNTP- -0.002 Posting- -0.002 Lines: -0.002 Host: -0.001 4 -0.0 in -0.0 the -0.0 Q)
inputs
-0.003
From:
-0.008
prb@
-0.007
access.
-0.006
digex.
-0.007
net (
-0.008
Pat)
-0.004
Subject:
0.001
Re:
-0.004
Near
0.004
Miss
-0.004
Asteroids (
-0.0
Q)
-0.003
Organization:
-0.006
Express
-0.007
Access
-0.008
Online
-0.005
Communications,
-0.006
Greenbelt,
-0.003
MD
-0.004
USA
-0.002
Lines:
-0.001
4
-0.004
Distribution:
-0.008
sci
-0.003
NNTP-
-0.002
Posting-
-0.002
Host:
-0.008
access.
-0.009
digex.
-0.008
net
0.007
TRry
-0.0
the
0.006
SKywatch
-0.015
project
-0.0
in
0.017
Arizona.
-0.011
pat


[1]
outputs
alt.atheism
comp.graphics
rec.motorcycles
sci.space
talk.politics.misc


0.50.30.10.70.90.1462830.146283base value0.7536120.753612falt.atheism(inputs)0.075 alexia.lis.uiuc. 0.068 Debate cobb@ 0.053 lis. 0.052 alexia. 0.042 edu Nobody can explain 0.036 necessarily a religious term, 0.035 that God 0.032 his comments. He stated 0.031 is not 0.027 cobb@ 0.026 but 0.023 proven scientific fact. I 0.022 point -- I can quote 0.022 think I got his 0.02 dabbling into a book called Christianity and 0.02 theories, without being a 0.02 give explanation for events 0.019 creationism, I' 0.018 Cobb "...and I won' 0.018 Mike Cobb) 0.017 From: 0.016 other scientific terms that 0.014 or 0.013 ve started 0.013 t raise taxes on 0.011 could be used as 0.011 examples he 0.009 question that I had come from one of 0.008 the Nature of Science 0.008 m being vague. The 0.006 gave were quarks 0.006 Clinton 3rd 0.005 Urbana -Bill 0.005 uiuc.edu ( 0.005 are not measurable in 0.005 the section if I' 0.005 Subject: Science 0.004 for my programs." Champaign- 0.004 and -0.027 measured, tested, etc.? -0.026 everything to anybody. G.K.Chesterton -0.016 per various -0.015 University of -0.015 Illinois at -0.014 19 As -0.014 Urbana Lines: -0.014 MAC -- **************************************************************** Michael -0.014 theories Organization: -0.013 A. -0.01 Illinois class to pay -0.009 and of themselves, or can everything be quantified, -0.009 and continental plates. Are there -0.006 by JP Moreland. A -0.006 and -0.006 threads on science -0.001 explanations of science or -0.001 the middle University of -0.001 parts of theories that
inputs
0.017
From:
0.027
cobb@
0.052
alexia.
0.053
lis.
0.005 / 2
uiuc.edu (
0.018 / 2
Mike Cobb)
0.005 / 2
Subject: Science
0.004
and
-0.014 / 2
theories Organization:
-0.015 / 2
University of
-0.015 / 2
Illinois at
-0.014 / 2
Urbana Lines:
-0.014 / 2
19 As
-0.016 / 2
per various
-0.006 / 3
threads on science
-0.006
and
0.019 / 2
creationism, I'
0.013 / 2
ve started
0.02 / 7
dabbling into a book called Christianity and
0.008 / 4
the Nature of Science
-0.006 / 4
by JP Moreland. A
0.009 / 8
question that I had come from one of
0.032 / 4
his comments. He stated
0.035 / 2
that God
0.031 / 2
is not
0.036 / 4
necessarily a religious term,
0.026
but
0.011 / 4
could be used as
0.016 / 4
other scientific terms that
0.02 / 4
give explanation for events
0.014
or
0.02 / 4
theories, without being a
0.023 / 4
proven scientific fact. I
0.022 / 4
think I got his
0.022 / 4
point -- I can quote
0.005 / 4
the section if I'
0.008 / 4
m being vague. The
0.011 / 2
examples he
0.006 / 3
gave were quarks
-0.009 / 5
and continental plates. Are there
-0.001 / 4
explanations of science or
-0.001 / 4
parts of theories that
0.005 / 4
are not measurable in
-0.009 / 8
and of themselves, or can everything be quantified,
-0.027 / 3
measured, tested, etc.?
-0.014 / 2
MAC -- **************************************************************** Michael
-0.013
A.
0.018 / 4
Cobb "...and I won'
0.013 / 4
t raise taxes on
-0.001 / 4
the middle University of
-0.01 / 4
Illinois class to pay
0.004 / 4
for my programs." Champaign-
0.005 / 2
Urbana -Bill
0.006 / 2
Clinton 3rd
0.068 / 2
Debate cobb@
0.075 / 3
alexia.lis.uiuc.
0.042 / 4
edu Nobody can explain
-0.026 / 6
everything to anybody. G.K.Chesterton

Bar Plots

Bar Plot 1

# Grab the shap module and the computed Explanation object off the explainer.
shap = partition_explainer.shap
shap_values = partition_explainer.shap_values

# Mean token attribution across the batch toward sample 0's predicted class
# (presumably the Explanation's output axis is indexable by class name via
# output_names — TODO confirm against the explainer implementation).
shap.plots.bar(partition_explainer.shap_values[:,:, selected_categories[preds[0]]].mean(axis=0), max_display=15,
               order=shap.Explanation.argsort.flip)
../../_images/5099283415b9bc9dfa56d06514596007986429a9c649c0e4d163eb30219f7df6.png

Bar Plot 2

# Per-token attributions for sample 0 toward its predicted class.
shap.plots.bar(shap_values[0,:, selected_categories[preds[0]]], max_display=15,
               order=shap.Explanation.argsort.flip)
../../_images/71173b6d5160275967efccef2537775770bee67fc40ae58104af6db2a0f4e98a.png

Bar Plot 3

# Mean token attribution across the batch toward sample 1's predicted class.
shap.plots.bar(shap_values[:,:, selected_categories[preds[1]]].mean(axis=0), max_display=15,
               order=shap.Explanation.argsort.flip)
../../_images/a3fb331ebb3f524e3615a680e811d43ea47458da27967599432ff46e2c9cc2a9.png

Bar Plot 4

# Per-token attributions for sample 1 toward its predicted class.
shap.plots.bar(shap_values[1,:, selected_categories[preds[1]]], max_display=15,
               order=shap.Explanation.argsort.flip)
../../_images/54806bebc380c06a9fb05261657b5b10dad96b193eb8c26a93fde324798a3019.png

Waterfall Plots

Waterfall Plot 1

shap.waterfall_plot(shap_values[0][:, selected_categories[preds[0]]], max_display=15)
../../_images/9f7cc6dd83585f1a641f7b6162bd70f2b8f0cd694828d852cec3f8f632839610.png

Waterfall Plot 2

shap.waterfall_plot(shap_values[1][:, selected_categories[preds[1]]], max_display=15)
../../_images/0b357c308a2529c8963f20ca5ab85c4ef268e08dc796536909003cffe6270a21.png

Force Plot

import re

# Tokenize the first sample for the force plot's feature names. Use a raw
# string for the regex: "\W+" contains an invalid escape sequence that raises
# a warning in modern Python, and the rest of the notebook already uses r"\W+".
tokens = re.split(r"\W+", X_batch_text[0].lower())
shap.initjs()
# Force plot for sample 0's predicted class; tokens[:-1] drops the empty
# trailing token that \W+ splitting leaves at the end of the text.
shap.force_plot(shap_values.base_values[0][preds[0]], shap_values[0][:, preds[0]].values,
                feature_names = tokens[:-1], out_names=selected_categories[preds[0]])
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.